{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "399f2fcf-9241-4910-a30d-6ca19880d0ad",
   "metadata": {},
   "source": [
    "## Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "97e68340-0096-475e-8ed8-22f5d627e3ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from fairseq import utils, tasks\n",
    "from fairseq import checkpoint_utils\n",
    "from utils.eval_utils import eval_step\n",
    "from tasks.mm_tasks import ImageGenTask\n",
    "from models.unival import UnIVALModel\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "import time\n",
    "\n",
    "\n",
    "# Run on the GPU whenever CUDA is available\n",
    "use_cuda = torch.cuda.is_available()\n",
    "# Half precision is enabled exactly when CUDA is\n",
    "use_fp16 = use_cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "719cef65-c00c-4c9c-90b2-e660b386c3d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<function fairseq.tasks.register_task.<locals>.register_task_cls(cls)>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
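    "# Register the image generation task ('image_gen')\n",
    "# NOTE(review): the recorded output shows this call only returns fairseq's decorator function;\n",
    "# presumably ImageGenTask is already registered when its module is imported -- confirm.\n",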
    "tasks.register_task('image_gen', ImageGenTask)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc9c1d7b-898b-4ac4-adf3-832891d9e4be",
   "metadata": {},
   "source": [
    "### Load model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "568bb6ea-eef9-4024-98e6-35e74b5ffeec",
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyError",
     "evalue": "'\"beam\"'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_1127924/2667361285.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     41\u001b[0m models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n\u001b[1;32m     42\u001b[0m     \u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit_paths\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcheckpoint_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m     \u001b[0marg_overrides\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moverrides\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     44\u001b[0m )\n\u001b[1;32m     45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/ofa_ours/fairseq/fairseq/checkpoint_utils.py\u001b[0m in \u001b[0;36mload_model_ensemble_and_task\u001b[0;34m(filenames, arg_overrides, task, strict, suffix, num_shards, state)\u001b[0m\n\u001b[1;32m    455\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    456\u001b[0m                 \u001b[0;31m# model parallel checkpoint or unsharded checkpoint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 457\u001b[0;31m                 \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcfg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    458\u001b[0m                 model.load_state_dict(\n\u001b[1;32m    459\u001b[0m                     \u001b[0mstate\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"model\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstrict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstrict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_cfg\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcfg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/unival/tasks/mm_tasks/image_gen.py\u001b[0m in \u001b[0;36mbuild_model\u001b[0;34m(self, cfg)\u001b[0m\n\u001b[1;32m    125\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    126\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mbuild_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcfg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m         \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcfg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    129\u001b[0m         \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcurrent_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/unival/tasks/ofa_task.py\u001b[0m in \u001b[0;36mbuild_model\u001b[0;34m(self, cfg)\u001b[0m\n\u001b[1;32m    220\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    221\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mbuild_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcfg\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mFairseqDataclass\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m         \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcfg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    223\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcfg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbpe\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'bert'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    224\u001b[0m             bpe_dict = {\n",
      "\u001b[0;32m~/ofa_ours/fairseq/fairseq/tasks/fairseq_task.py\u001b[0m in \u001b[0;36mbuild_model\u001b[0;34m(self, cfg)\u001b[0m\n\u001b[1;32m    318\u001b[0m         \u001b[0;32mfrom\u001b[0m \u001b[0mfairseq\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquantization_utils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    319\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 320\u001b[0;31m         \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcfg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    321\u001b[0m         \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquantization_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquantize_model_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcfg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    322\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/ofa_ours/fairseq/fairseq/models/__init__.py\u001b[0m in \u001b[0;36mbuild_model\u001b[0;34m(cfg, task)\u001b[0m\n\u001b[1;32m    100\u001b[0m         \u001b[0;34mf\"Could not infer model type from {cfg}. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    101\u001b[0m         \"Available models: {}\".format(\n\u001b[0;32m--> 102\u001b[0;31m             \u001b[0mMODEL_DATACLASS_REGISTRY\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    103\u001b[0m         )\n\u001b[1;32m    104\u001b[0m         \u001b[0;34m+\u001b[0m \u001b[0;34mf\" Requested model type: {model_type}\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyError\u001b[0m: '\"beam\"'"
     ]
    }
   ],
   "source": [
    "# Checkpoint of the pretrained image generation model\n",
    "checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_2_base_s2_hsep1_long/checkpoint_best1.pt'\n",
    "# Alternative checkpoints, kept for reference:\n",
    "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofa_stage_1_base_s2_hsep1_long/checkpoint_best.pt'\n",
    "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_1_base_s2_long/checkpoint_best.pt'\n",
    "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_base_best.pt'\n",
    "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_large_best.pt'\n",
    "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_1_base_s2_hsep1_long/checkpoint_best.pt'\n",
    "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_2_base_s2_hsep1_long/checkpoint_best.pt'\n",
    "\n",
    "# Weights / config files of the auxiliary models referenced by the overrides\n",
    "clip_model_path = '/data/mshukor/data/ofa/clip/ViT-B-16.pt'\n",
    "vqgan_model_path = '/data/mshukor/data/ofa/vqgan/last.ckpt'\n",
    "vqgan_config_path = '/data/mshukor/data/ofa/vqgan/model.yaml'\n",
    "video_model_path = '/data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth'\n",
    "resnet_model_path = '/data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth'\n",
    "\n",
    "# Handed to the task through the `gen_images_path` override\n",
    "gen_images_path = 'results/image_gen/'\n",
    "\n",
    "# Settings passed as `arg_overrides` when the checkpoint is loaded\n",
    "overrides = dict(\n",
    "    bpe_dir=\"utils/BPE\",\n",
    "    eval_cider=False,\n",
    "    beam=24,\n",
    "    max_len_b=1024,\n",
    "    max_len_a=0,\n",
    "    min_len=1024,\n",
    "    sampling_topk=256,\n",
    "    constraint_range=\"50265,58457\",\n",
    "    clip_model_path=clip_model_path,\n",
    "    vqgan_model_path=vqgan_model_path,\n",
    "    vqgan_config_path=vqgan_config_path,\n",
    "    seed=42,\n",
    "    video_model_path=video_model_path,\n",
    "    resnet_model_path=resnet_model_path,\n",
    "    gen_images_path=gen_images_path,\n",
    "    patch_image_size=256,\n",
    "    temperature=1.5,\n",
    ")\n",
    "\n",
    "# Restore the model ensemble together with its config and task\n",
    "models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n",
    "    filenames=utils.split_paths(checkpoint_path),\n",
    "    arg_overrides=overrides,\n",
    ")\n",
    "task.cfg.sampling_times = 2\n",
    "\n",
    "# Switch every model to eval mode, cast to fp16 if requested and move it to the GPU\n",
    "for model in models:\n",
    "    model.eval()\n",
    "    if use_fp16:\n",
    "        model.half()\n",
    "    # stay on the CPU without CUDA or when pipeline model parallelism is enabled\n",
    "    keep_on_cpu = (not use_cuda) or cfg.distributed_training.pipeline_model_parallel\n",
    "    if not keep_on_cpu:\n",
    "        model.cuda()\n",
    "    model.prepare_for_inference_(cfg)\n",
    "\n",
    "# Sequence generator built from the generation config\n",
    "generator = task.build_generator(models, cfg.generation)\n",
    "\n",
    "# Special-token tensors / index used by the text preprocessing helpers\n",
    "bos_item, eos_item = (\n",
    "    torch.LongTensor([token_idx])\n",
    "    for token_idx in (task.src_dict.bos(), task.src_dict.eos())\n",
    ")\n",
    "pad_idx = task.src_dict.pad()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e4a45ec-bce1-495b-8033-3b574367b360",
   "metadata": {},
   "source": [
    "### Preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9f2e7e32-c9a0-43b3-bf86-2419d9f7dfe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_text(text, length=None, append_bos=False, append_eos=False):\n",
    "    \"\"\"BPE-encode `text` and map it to indices of the task's target dictionary.\n",
    "\n",
    "    Args:\n",
    "        text: raw string to encode.\n",
    "        length: if not None, keep at most this many tokens (applied before BOS/EOS are added).\n",
    "        append_bos: prepend the module-level `bos_item` when true.\n",
    "        append_eos: append the module-level `eos_item` when true.\n",
    "\n",
    "    Returns:\n",
    "        A LongTensor of token indices.\n",
    "    \"\"\"\n",
    "    bpe_line = task.bpe.encode(text)\n",
    "    tokens = task.tgt_dict.encode_line(\n",
    "        line=bpe_line, add_if_not_exist=False, append_eos=False\n",
    "    ).long()\n",
    "    if length is not None:\n",
    "        tokens = tokens[:length]\n",
    "\n",
    "    # collect the optional special tokens around the text tokens, then join once\n",
    "    pieces = [tokens]\n",
    "    if append_bos:\n",
    "        pieces.insert(0, bos_item)\n",
    "    if append_eos:\n",
    "        pieces.append(eos_item)\n",
    "    return tokens if len(pieces) == 1 else torch.cat(pieces)\n",
    "\n",
    "\n",
    "def construct_sample(query: str):\n",
    "    \"\"\"Build the single-example input batch for the image generation task.\n",
    "\n",
    "    The caption is wrapped in the prompt template, encoded with BOS and EOS,\n",
    "    and given a leading batch dimension.\n",
    "\n",
    "    Args:\n",
    "        query: caption describing the image to generate.\n",
    "\n",
    "    Returns:\n",
    "        A dict with an `id` array and a `net_input` dict holding\n",
    "        `src_tokens`, `src_lengths` and `code_masks`.\n",
    "    \"\"\"\n",
    "    prompt = \" what is the complete image? caption: {}\".format(query)\n",
    "    src_tokens = encode_text(prompt, append_bos=True, append_eos=True).unsqueeze(0)\n",
    "    # number of non-padding tokens in each row\n",
    "    src_lengths = torch.LongTensor(\n",
    "        [row.ne(pad_idx).long().sum() for row in src_tokens]\n",
    "    )\n",
    "    net_input = {\n",
    "        \"src_tokens\": src_tokens,\n",
    "        \"src_lengths\": src_lengths,\n",
    "        \"code_masks\": torch.tensor([True]),\n",
    "    }\n",
    "    return {\"id\": np.array(['42']), \"net_input\": net_input}\n",
    "\n",
    "\n",
    "def apply_half(t):\n",
    "    \"\"\"Cast a float32 tensor to float16; return any other tensor unchanged.\"\"\"\n",
    "    return t.to(dtype=torch.half) if t.dtype is torch.float32 else t\n",
    "\n",
    "\n",
    "def image_generation(caption):\n",
    "    \"\"\"Generate images for `caption` and tile the first results into a 2x2 grid.\n",
    "\n",
    "    Args:\n",
    "        caption: text prompt describing the image to generate.\n",
    "\n",
    "    Returns:\n",
    "        A 512x512 RGB PIL image. Up to four results are pasted row by row on\n",
    "        a 256-pixel grid; cells without a result keep the canvas default colour.\n",
    "    \"\"\"\n",
    "    grid_side = 2   # tiles per row and per column\n",
    "    pic_size = 256  # pixel offset between neighbouring tiles\n",
    "\n",
    "    sample = construct_sample(caption)\n",
    "    if use_cuda:\n",
    "        sample = utils.move_to_cuda(sample)\n",
    "    if use_fp16:\n",
    "        sample = utils.apply_to_sample(apply_half, sample)\n",
    "\n",
    "    print('|Start|', time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()), caption)\n",
    "    with torch.no_grad():\n",
    "        # the second return value (scores) is not needed here\n",
    "        result, _ = eval_step(task, generator, models, sample)\n",
    "\n",
    "    # keep the top results (ranked by clip); tolerate fewer than four of them\n",
    "    # instead of failing with an IndexError\n",
    "    num_tiles = min(len(result), grid_side * grid_side)\n",
    "    images = [result[i]['image'] for i in range(num_tiles)]\n",
    "\n",
    "    retImage = Image.new('RGB', (pic_size * grid_side, pic_size * grid_side))\n",
    "    print('|FINISHED|', time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()), caption)\n",
    "    for i, image in enumerate(images):\n",
    "        # integer row/column of tile i (no float division)\n",
    "        row, col = divmod(i, grid_side)\n",
    "        retImage.paste(image, (col * pic_size, row * pic_size))\n",
    "    return retImage"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44dec799-c5c2-4d22-8b08-7a7ca2cdf3c9",
   "metadata": {},
   "source": [
    "### Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "02d5cd7a-8d63-4fa4-9da1-d4b79ec01445",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "|Start| 2023-06-29 12:57:39 A brown horse in the street\n",
      "|FINISHED| 2023-06-29 12:59:03 A brown horse in the street\n"
     ]
    }
   ],
   "source": [
    "# Other prompts that can be swapped in for `query`:\n",
    "# query = \"Cattle grazing on grass near a lake surrounded by mountain.\"\n",
    "# query = 'A street scene with a double-decker bus on the road.'\n",
    "# query = 'A path.'\n",
    "\n",
    "# Active prompt\n",
    "query = \"A brown horse in the street\"\n",
    "\n",
    "retImage = image_generation(query)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a8a1654-1f17-41c7-b410-c7491a96dcee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name the file after the prompt. '/' is replaced so that a prompt containing\n",
    "# a path separator cannot point into a (possibly missing) sub-directory.\n",
    "retImage.save(f\"{query.replace('/', '_')}.png\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ofa",
   "language": "python",
   "name": "ofa"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
